#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<vector>
// Debug helper: prints "<expr>=<value>" followed by std::endl (newline + flush).
// The argument is parenthesised so an expression containing operators that bind
// looser than << (e.g. debug(a<<1), debug(x&y)) prints the value of the whole
// expression rather than being split apart by the surrounding << chain.
// The trailing semicolon is kept so existing `debug(x)` call sites written
// without a semicolon still compile.
#define debug(a) cout<<#a<<"="<<(a)<<endl;
using namespace std;
const int maxn=5e5+100;   // capacity of the global num[] table (500100 entries)
typedef long long LL;
const LL mod=998244353;   // modulus applied to every value computed in main
// Reads the next (optionally negative) decimal integer from stdin via getchar().
// Characters before the first digit are skipped; any '-' among them makes the
// result negative. Reading stops at the first non-digit after the number, and
// that character is consumed.
// Returns 0 if end of input is reached before any digit is seen.
// `ch` is an int (not char) so EOF can be distinguished from real bytes; the
// original char-based version looped forever at end of input. Values passed to
// isdigit are cast to unsigned char because a negative char is undefined
// behaviour for <cctype> functions.
inline long long read(){
  long long x = 0, f = 1;
  int ch = getchar();
  while (ch != EOF && !isdigit(static_cast<unsigned char>(ch))) {
    if (ch == '-') f = -1;
    ch = getchar();
  }
  while (ch != EOF && isdigit(static_cast<unsigned char>(ch))) {
    x = x * 10 + (ch - '0');
    ch = getchar();
  }
  return x * f;
}
LL num[maxn];
int main(void)
{
  cin.tie(0);std::ios::sync_with_stdio(false);
  LL n,m;cin>>n>>m;
  for(LL i=1;i<=n;i++){
    num[i]=( (n-i+1)%mod*i%mod)%mod;
  }
  ///for(LL i=1;i<=num[i];i++) cout<<num[i]<<" ";
  ///cout<<"\n";
  LL sum=0;
  for(LL i=1;i<=m;i++){
    sum=(sum%mod+(num[i]%mod*(i-1)%mod)%mod)%mod;
  }
  cout<<sum%mod<<"\n";
return 0;
}
